from dataclasses import dataclass
from typing import Union, Optional, List, Tuple, Dict, Sequence
from tools.codegen.api.translate import translate
from tools.codegen.model import NativeFunctionsGroup, ScalarType, UfuncKey, DispatchKey, BaseType, BaseTy, Argument
import tools.codegen.api.ufunc as ufunc
from tools.codegen.api.ufunc import UfunctorBindings
from tools.codegen.api.types import (
    StructuredImplSignature, scalar_t, opmath_t, Binding, CType,
    BaseCType, Expr, NamedCType, ScalarTypeToCppMapping, VectorizedCType
)
from tools.codegen.context import with_native_function

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                  CUDA STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary

# TODO: use BackendIndex
# dispatch_key: DispatchKey  # only CPU/CUDA right now


# Represents functors for implementing CUDA ufuncs.
# Functors are templated by scalar_t because when USERS instantiate functors
# they are templated.  A functor looks something like this:
#
#   template <typename scalar_t>
#   struct CUDAFunctorOnSelf_add {
#     using opmath_t = at::opmath_type<scalar_t>;
#     opmath_t other_;
#     opmath_t alpha_;
#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
#         : other_(other), alpha_(alpha) {}
#     __device__ scalar_t operator()(scalar_t self) {
#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#     }
#   };
#
@dataclass(frozen=True)
class UfunctorSignature:
    """Describes one CUDA functor struct: its constructor and its ``operator()``.

    The constructor/apply split comes from ``ufunc.ufunctor_arguments``;
    ``scalar_tensor_idx`` is forwarded to it unchanged.  ``name`` is the
    C++ struct name used when rendering the constructor.
    """
    g: NativeFunctionsGroup
    scalar_tensor_idx: Optional[int]
    name: str

    def arguments(self) -> UfunctorBindings:
        """Return the ctor/apply bindings computed by ``ufunc.ufunctor_arguments``."""
        return ufunc.ufunctor_arguments(
            self.g,
            scalar_tensor_idx=self.scalar_tensor_idx,
            scalar_t=scalar_t,
        )

    def fields(self) -> List[Binding]:
        """Return the member bindings: one per ctor argument, renamed ``<name>_``."""
        # The trailing underscore is the conventional spelling for members.
        members: List[Binding] = []
        for ctor_arg in self.arguments().ctor:
            members.append(ctor_arg.rename(f"{ctor_arg.name}_"))
        return members

    def returns_type(self) -> CType:
        """Return the C++ type produced by ``operator()``."""
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        """Render the member declarations, one ``<type> <name>;`` per line."""
        member_decls = [f"{member.type} {member.name};" for member in self.fields()]
        return "\n".join(member_decls)

    def inline_defn_ctor(self) -> str:
        """Render the inline constructor that copies each argument into its member."""
        ctor_args = self.arguments().ctor
        params = ', '.join(arg.decl() for arg in ctor_args)
        # NB: hypothetically could do this with translate but the
        # transition here is very regular
        initializers = ', '.join(f"{arg.name}_({arg.name})" for arg in ctor_args)
        return f"{self.name}({params}) : {initializers} {{}}"

    def decl_apply(self) -> str:
        """Render the declaration (no body) of the const ``operator()``."""
        params = ', '.join(arg.decl() for arg in self.arguments().apply)
        return_type = self.returns_type().cpp_type()
        return f"{return_type} operator()({params}) const"


@dataclass(frozen=True)
class UfuncSignature:
    """Signature of the ufunc template function that the generated loops call.

    ``name`` is the callee name used verbatim in the rendered call, and
    ``compute_t`` is forwarded to ``ufunc.ufunc_arguments`` to determine the
    argument types.
    """
    g: NativeFunctionsGroup
    name: str
    compute_t: CType

    def arguments(self) -> List[Binding]:
        """Return the bindings computed by ``ufunc.ufunc_arguments`` for ``compute_t``."""
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
        """Render ``name(args...)``, with each argument translated from ``ctx``."""
        translated = translate(ctx, self.arguments())
        arg_list = ', '.join(e.expr for e in translated)
        return f"{self.name}({arg_list})"


# steps:
#   1. take the functional signature
#   2. use api.ufunc to convert it to template signature.  this establishes
#      the type of the template function
#   3. use api.ufunc (II) to generate a split struct / operator() signature.
#      this establishes the context in which we call the template signature
#
# StructuredImplSignature context
#   ~> functor constructor sig
#
# Functor constructor context
#   ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
#   ~> template sig
#

def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    """Return True iff the functional signature has exactly two tensor-like non-out arguments."""
    tensor_like_args = [
        arg
        for arg in g.functional.func.arguments.flat_non_out
        if arg.type.is_tensor_like()
    ]
    return len(tensor_like_args) == 2

def compute_ufunc_cuda_functors(g: NativeFunctionsGroup) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
    """Compute the CUDA functor signatures for ``g`` and the source of generated functors.

    Returns a pair ``(ufunctor_sigs, ufunctors_src)``:

    * ``ufunctor_sigs`` maps each supported dtype to the functor signature to
      use for each CUDA functor key.
    * ``ufunctors_src`` is the C++ source of the functor structs generated
      here.  Keys that are defined directly in ``g.out.ufunc_inner_loop``
      contribute a signature but no source.

    Dtypes are visited in sorted (by ``str``) order so that the insertion
    order of ``ufunctor_sigs`` -- and therefore anything rendered by
    iterating over it -- does not depend on set iteration order.

    Raises ``AssertionError`` if a scalar-specialized key is used on a
    function that is not eligible for it, if ScalarOnly and Generic disagree
    on the ufunc name, or if neither is defined for a key that needs one.
    """
    # First, build the functors.
    ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: List[str] = []
    loops = g.out.ufunc_inner_loop
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None
    }
    if eligible_for_binary_scalar_specialization(g):
        keys = [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther, UfuncKey.CUDAFunctor]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"
    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user has already done it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name)
            # Sort for a reproducible insertion order (see docstring)
            for dtype in sorted(loops[k].supported_dtypes, key=str):
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match  (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes = set()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert ufunc_name == loops[lk].name, "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        assert ufunc_name is not None, \
            f"cannot generate {k}: neither ScalarOnly nor Generic is defined in ufunc_inner_loop"

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name)
        # Sort for a reproducible insertion order (see docstring)
        for dtype in sorted(supported_dtypes, key=str):
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        ufunc_sig = UfuncSignature(g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t))
        # operator() may refer both to the stored members and to its own arguments
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        ufunctors.append(f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
""")

    return ufunctor_sigs, "\n".join(ufunctors)

# Describes one "this operand is a CPU scalar" specialization that
# compute_ufunc_cuda_dtype_body may emit for a binary ufunc.
@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    # Zero-based position of the scalar operand; the code generator adds 1
    # to it before passing it to iter.is_cpu_scalar / iter.scalar_value /
    # iter.remove_operand
    scalar_idx: int
    # Name under which the extracted scalar value is made available when
    # translating the functor constructor arguments
    ctor_tensor: str
    # Key of the functor (in the per-dtype inner loops) used for this case
    ufunc_key: UfuncKey

# The scalar specializations considered for binary ufuncs, in the order in
# which their "else if" branches are emitted: when 'self' (position 0) is the
# scalar the CUDAFunctorOnOther functor is used, and when 'other' (position 1)
# is the scalar the CUDAFunctorOnSelf functor is used.
BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=0,
        ctor_tensor='self',
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
    ),
    BinaryScalarSpecializationConfig(
        scalar_idx=1,
        ctor_tensor='other',
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
    ),
]

def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup, dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfunctorSignature], parent_ctx: Sequence[Binding]
) -> str:
    """Render the C++ body of one dtype case of the CUDA kernel.

    The body is an ``if``/``else if``/``else`` chain: one ``else if`` branch
    per entry of ``BinaryScalarSpecializationConfigs`` whose functor key is
    present in ``inner_loops``, followed by an ``else`` branch that uses the
    ``UfuncKey.CUDAFunctor`` functor (which must be present).

    ``parent_ctx`` is the set of bindings available for translating functor
    constructor arguments.  ``g`` and ``dtype`` are not consulted here.
    """
    def render_ctor_args(sig: UfunctorSignature, ctx: Sequence[Union[Expr, Binding]]) -> str:
        # Comma-separated constructor argument expressions for ``sig``
        return ', '.join(e.expr for e in translate(ctx, sig.arguments().ctor))

    chunks: List[str] = [
        "using opmath_t = at::opmath_type<scalar_t>;",
        "if (false) {}\n",  # for ease of codegen
    ]

    active_configs = [
        config for config in BinaryScalarSpecializationConfigs
        if config.ufunc_key in inner_loops
    ]
    for config in active_configs:
        specialized_sig = inner_loops[config.ufunc_key]
        operand_idx = config.scalar_idx + 1
        # Build a fresh, wider-typed context rather than mutating the
        # caller's sequence; the extra entry exposes the scalar's value
        # under the name of the tensor it stands in for
        scalar_expr = Expr(
            expr=f"iter.scalar_value<opmath_t>({operand_idx})",
            type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
        )
        ctx: List[Union[Expr, Binding]] = [*parent_ctx, scalar_expr]
        ctor_args = render_ctor_args(specialized_sig, ctx)

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        chunks.append(f"""\
else if (iter.is_cpu_scalar({operand_idx})) {{
  {specialized_sig.name}<scalar_t> ufunctor({ctor_args});
  iter.remove_operand({operand_idx});
  gpu_kernel(iter, ufunctor);
}}""")

    generic_sig = inner_loops[UfuncKey.CUDAFunctor]
    generic_ctor_args = render_ctor_args(generic_sig, parent_ctx)
    chunks.append(f"""
else {{
  gpu_kernel(iter, {generic_sig.name}<scalar_t>({generic_ctor_args}));
}}
    """)
    return "".join(chunks)

@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    """Render the full CUDA source for ``g``.

    The output contains the generated functors, the stub type and dispatch
    declarations, the kernel (a switch over the iterator's common dtype),
    its dispatch registration, and the structured impl that calls the
    kernel directly.
    """
    # First, build the functors, indexing them by dtype
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))

    def render_case(dtype: ScalarType, sigs_for_dtype: Dict[UfuncKey, UfunctorSignature]) -> str:
        # One AT_PRIVATE_CASE_TYPE entry of the dtype switch
        cpp_type = ScalarTypeToCppMapping[dtype]
        case_body = compute_ufunc_cuda_dtype_body(g, dtype, sigs_for_dtype, sig.arguments())
        return f"""
AT_PRIVATE_CASE_TYPE("{sig.name}", at::ScalarType::{dtype}, {cpp_type},
  [&]() {{
    {case_body}
  }}
)
"""

    # Next, build the conditionals
    dtype_cases_str = "\n".join(
        render_case(dtype, sigs_for_dtype)
        for dtype, sigs_for_dtype in ufunctor_sigs.items()
    )

    stub_sig = StubSignature(g)
    type_defn = stub_sig.type_defn()
    dispatch_decl = stub_sig.dispatch_decl()
    kernel_defn = stub_sig.kernel_defn()
    impl_defn = sig.defn()
    kernel_call = stub_sig.direct_call(sig.arguments())

    return f"""
{ufunctors}

{type_defn};
{dispatch_decl};

{kernel_defn} {{
  at::ScalarType st = iter.common_dtype();
  RECORD_KERNEL_FUNCTION_DTYPE("{sig.name}", st);
  switch (st) {{
    {dtype_cases_str}
    default:
      TORCH_CHECK(false, "{sig.name}", " not implemented for '", toString(st), "'");
  }}
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});

{impl_defn} {{
  {kernel_call};
}}
"""

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                   CPU STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

@dataclass(frozen=True)
class StubSignature:
    """Names and C++ snippets for the dispatch stub and kernel of ``g``.

    Every name is derived from the functional operator's base name by
    appending a fixed suffix (``_stub``, ``_kernel`` or ``_fn``).
    """
    g: NativeFunctionsGroup

    @property
    def _op_name(self) -> str:
        # Base operator name shared by all derived identifiers
        return str(self.g.functional.func.name.name)

    @property
    def name(self) -> str:
        """Name of the dispatch stub."""
        return f"{self._op_name}_stub"

    @property
    def kernel_name(self) -> str:
        """Name of the kernel function."""
        return f"{self._op_name}_kernel"

    @property
    def type_name(self) -> str:
        """Name of the function-pointer type alias."""
        return f"{self._op_name}_fn"

    def arguments(self) -> List[Binding]:
        """Return the extra stub arguments computed by ``ufunc.stub_arguments``."""
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        """Render the function-pointer type of the stub."""
        arg_types = ', '.join(a.type for a in self.arguments())
        return f"void(*)(TensorIteratorBase&, {arg_types})"

    def dispatch_decl(self) -> str:
        """Render the ``DECLARE_DISPATCH`` invocation (without trailing semicolon)."""
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        """Render the ``DEFINE_DISPATCH`` invocation (without trailing semicolon)."""
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        """Render the kernel function header (no body)."""
        params = ', '.join(a.defn() for a in self.arguments())
        return f"void {self.kernel_name}(TensorIteratorBase& iter, {params})"

    def type_defn(self) -> str:
        """Render the ``using`` alias for the stub's function-pointer type."""
        return f"using {self.type_name} = {self.type()}"

    # must be called from context where this is TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        """Render a call through the dispatch stub, translating arguments from ``ctx``."""
        call_args = ', '.join(a.expr for a in translate(ctx, self.arguments()))
        return f"{self.name}(device_type(), *this, {call_args})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        """Render a direct call to the kernel, translating arguments from ``ctx``."""
        call_args = ', '.join(a.expr for a in translate(ctx, self.arguments()))
        return f"{self.kernel_name}(*this, {call_args})"

@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    """Render the CPU structured impl for ``g``.

    The output declares and defines the dispatch stub, then defines the
    structured impl, whose body is a single call through that stub.
    """
    stub_sig = StubSignature(g)
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    type_defn = stub_sig.type_defn()
    dispatch_decl = stub_sig.dispatch_decl()
    dispatch_defn = stub_sig.dispatch_defn()
    impl_defn = sig.defn()
    stub_call = stub_sig.call(sig.arguments())

    return f"""
{type_defn};
{dispatch_decl};
{dispatch_defn};

{impl_defn} {{
  {stub_call};
}}
"""

def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup, dtype: ScalarType, inner_loops: Dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding]
) -> str:
    """Render the C++ body of one dtype case of the CPU kernel.

    The body first unpacks every scalar parameter of ``parent_ctx`` into a
    ``scalar_t`` local (and, if a vectorized loop exists, a ``Vectorized``
    local), then calls ``cpu_kernel_vec`` when ``UfuncKey.CPUVector`` is in
    ``inner_loops`` and ``cpu_kernel`` otherwise.

    ``inner_loops`` must contain ``UfuncKey.CPUScalar`` and may only contain
    ``CPUScalar`` / ``CPUVector``; ``dtype`` is used for the assertion
    message only.
    """
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = inner_loops.get(UfuncKey.CPUVector)
    vectorized = vec_loop is not None

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    def needs_unpacking(b: Binding) -> bool:
        # Everything is unpacked except Argument-backed bindings whose type
        # is something other than Scalar
        return not (isinstance(b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar))

    unpacked = [b for b in parent_ctx if needs_unpacking(b)]

    # Setup scalar in scope: all scalar_t locals first, then the Vectorized ones
    setup_lines = [f"auto _s_{b.name} = {b.name}.to<scalar_t>();" for b in unpacked]
    ctx = [
        Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t)))
        for b in unpacked
    ]
    if vectorized:
        setup_lines += [
            f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            for b in unpacked
        ]
        ctx += [
            Expr(f"_v_{b.name}", NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))))
            for b in unpacked
        ]

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    tensor_args = [a for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()]
    for a in tensor_args:
        assert a.type == BaseType(BaseTy.Tensor)

    def lambda_params(elem_t: CType) -> List[Binding]:
        # One lambda parameter per tensor argument, typed as ``elem_t``
        return [
            Binding(name=a.name, nctype=NamedCType(a.name, elem_t), argument=a)
            for a in tensor_args
        ]

    def render_lambda(loop: UfuncSignature, params: Sequence[Binding]) -> str:
        # The call may refer to both the unpacked locals and the lambda parameters
        decls = ', '.join(p.decl() for p in params)
        call_ctx: List[Union[Expr, Binding]] = [*ctx, *params]
        return f"  [=]({decls}) {{ return {loop.call(call_ctx)}; }}"

    lambdas = [render_lambda(scalar_loop, lambda_params(BaseCType(scalar_t)))]
    kernel = "cpu_kernel"
    if vec_loop is not None:
        lambdas.append(render_lambda(vec_loop, lambda_params(VectorizedCType(BaseCType(scalar_t)))))
        kernel = "cpu_kernel_vec"

    body_str = '\n'.join(setup_lines)
    lambdas_str = ",\n".join(lambdas)
    return f"""
{body_str}
{kernel}(iter,
{lambdas_str}
);
"""

@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    """Render the CPU kernel source for ``g``.

    The output is the kernel function (a switch over the iterator's common
    dtype, wrapped in an anonymous namespace) followed by the stub type
    definition, dispatch declaration and dispatch registration.

    For each of CPUScalar / CPUVector, the ufunc to call for a dtype is the
    first match among: the key itself, ScalarOnly (CPUScalar only), Generic.

    Dtypes are visited in sorted (by ``str``) order so that the order of the
    emitted switch cases does not depend on set iteration order.
    """
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop
    ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
        # The compute type depends only on the key, so pick it once per key
        compute_t: CType
        if k is UfuncKey.CPUScalar:
            compute_t = BaseCType(scalar_t)
        elif k is UfuncKey.CPUVector:
            compute_t = VectorizedCType(BaseCType(scalar_t))
        else:
            raise AssertionError()

        lks = []
        # ORDER MATTERS: this specifies overriding precedence
        if k in loops:  # should happen rarely
            lks.append(k)
        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
            lks.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            lks.append(UfuncKey.Generic)
        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lk in lks:
            # Sort for a reproducible insertion order (see docstring)
            for dtype in sorted(loops[lk].supported_dtypes, key=str):
                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
                # An earlier (higher precedence) entry wins
                if k not in inner_ufunc_sigs:
                    inner_ufunc_sigs[k] = UfuncSignature(
                        g, name=f"ufunc::{loops[lk].name}",
                        compute_t=compute_t
                    )

    # Build the conditionals
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
        dtype_cases.append(f"""
AT_PRIVATE_CASE_TYPE("{stub_sig.name}", at::ScalarType::{dtype}, {ScalarTypeToCppMapping[dtype]},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  }}
)
""")

    dtype_cases_str = "\n".join(dtype_cases)
    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  at::ScalarType st = iter.common_dtype();
  RECORD_KERNEL_FUNCTION_DTYPE("{stub_sig.name}", st);
  switch (st) {{
    {dtype_cases_str}
    default:
      TORCH_CHECK(false, "{stub_sig.name}", " not implemented for '", toString(st), "'");
  }}
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""
